Library Imports

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

Template

spark = (
    SparkSession.builder
    .master("local")
    .appName("Section 3.2 - Range Join Conditions (WIP)")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)


sc = spark.sparkContext
geo_loc_table = spark.createDataFrame([
    (1, 10, "foo"), 
    (11, 36, "bar"), 
    (37, 59, "baz"),
], ["ipstart", "ipend", "loc"])

geo_loc_table.toPandas()
   ipstart  ipend  loc
0        1     10  foo
1       11     36  bar
2       37     59  baz
records_table = spark.createDataFrame([
    (1, 11), 
    (2, 38), 
    (3, 50),
],["id", "inet"])

records_table.toPandas()
   id  inet
0   1    11
1   2    38
2   3    50

Range Join Conditions

A naive approach (just specifying this as the range condition) would result in a full cartesian product and a filter that enforces the condition (tested using Spark 2.0). This has a horrible effect on performance, especially if the DataFrames are more than a few hundred thousand records.

source: http://zachmoshe.com/2016/09/26/efficient-range-joins-with-spark.html

The source of the problem is pretty simple. When you execute a join and the join condition is not equality based, the only thing Spark can do right now is expand it to a Cartesian product followed by a filter, which is pretty much what happens inside BroadcastNestedLoopJoin.

source: https://stackoverflow.com/questions/37953830/spark-sql-performance-join-on-value-between-min-and-max?answertab=active#tab-top

Option #1

join_condition = [
    records_table['inet'] >= geo_loc_table['ipstart'],
    records_table['inet'] <= geo_loc_table['ipend'],
]

df = records_table.join(geo_loc_table, join_condition, "left")

df.toPandas()
   id  inet  ipstart  ipend  loc
0   1    11       11     36  bar
1   2    38       37     59  baz
2   3    50       37     59  baz
df.explain()
== Physical Plan ==
BroadcastNestedLoopJoin BuildRight, LeftOuter, ((inet#252L >= ipstart#245L) && (inet#252L <= ipend#246L))
:- Scan ExistingRDD[id#251L,inet#252L]
+- BroadcastExchange IdentityBroadcastMode
   +- Scan ExistingRDD[ipstart#245L,ipend#246L,loc#247]
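
The plan above makes the quoted behaviour visible: with only range predicates, Spark falls back to BroadcastNestedLoopJoin, comparing every record against every range. As a sketch that is not part of the original notebook, the same query written as an explicit cross join plus a filter typically compiles to an equivalent plan (the exact operator can vary with Spark version and table sizes); naive_df below is an illustrative name.

# Sketch: explicit cross-join-plus-filter form of the same query. Catalyst
# usually plans this the same way as the range-condition join above.
naive_df = (
    records_table
    .crossJoin(geo_loc_table)
    .where(
        (records_table["inet"] >= geo_loc_table["ipstart"]) &
        (records_table["inet"] <= geo_loc_table["ipend"])
    )
)

naive_df.explain()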

Option #2

Collect the sorted range start points, broadcast them, and use a binary-search UDF to tag every record with the start of the range it falls into. The range join then becomes an ordinary equi-join on ipstart.

from bisect import bisect_right
from pyspark.sql.functions import udf
from pyspark.sql.types import LongType

# Broadcast the sorted list of range start points so every executor can run a
# local binary search. Note: materialize as a list -- in Python 3, map() returns
# an iterator, which bisect cannot index.
geo_start_bd = spark.sparkContext.broadcast([
    row.ipstart
    for row in geo_loc_table.select("ipstart").orderBy("ipstart").collect()
])

def find_le(x):
    'Find rightmost value less than or equal to x'
    i = bisect_right(geo_start_bd.value, x)
    if i:
        return geo_start_bd.value[i-1]
    return None
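
Because Broadcast.value is also available on the driver, find_le can be sanity-checked locally before wiring it into a UDF (this quick check is an addition, not part of the original notebook):

find_le(11)   # 11   -> start of the 11-36 range ("bar")
find_le(38)   # 37   -> start of the 37-59 range ("baz")
find_le(0)    # None -> below the smallest ipstart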

records_table_with_ipstart = records_table.withColumn(
    "ipstart", udf(find_le, LongType())("inet")
)

df = records_table_with_ipstart.join(geo_loc_table, ["ipstart"], "left")

df.toPandas()
   ipstart  id  inet  ipend  loc
0       37   2    38     59  baz
1       37   3    50     59  baz
2       11   1    11     36  bar
df.explain()
== Physical Plan ==
*(4) Project [ipstart#272L, id#251L, inet#252L, ipend#246L, loc#247]
+- SortMergeJoin [ipstart#272L], [ipstart#245L], LeftOuter
   :- *(2) Sort [ipstart#272L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(ipstart#272L, 200)
   :     +- *(1) Project [id#251L, inet#252L, pythonUDF0#281L AS ipstart#272L]
   :        +- BatchEvalPython [find_le(inet#252L)], [id#251L, inet#252L, pythonUDF0#281L]
   :           +- Scan ExistingRDD[id#251L,inet#252L]
   +- *(3) Sort [ipstart#245L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(ipstart#245L, 200)
         +- Scan ExistingRDD[ipstart#245L,ipend#246L,loc#247]
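
One caveat with Option #2 (an added note, not from the quoted sources): find_le only looks at ipstart, so if the ranges ever contain gaps, a record whose inet lies beyond the matched range's ipend still picks up that range's loc. A defensive post-join step such as the sketch below (df_checked is an illustrative name) keeps only genuine matches:

# Null out the location when the record falls past the end of the matched
# range; only relevant if the ip ranges can have gaps.
df_checked = df.withColumn(
    "loc",
    F.when(F.col("inet") <= F.col("ipend"), F.col("loc"))
)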
